{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "from keras.models import Model\n",
    "from keras.layers import Input\n",
    "from keras.layers.convolutional import Conv2D\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras import backend as K\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def format_decimal(arr, places=6):\n",
    "    \"\"\"Round each value in `arr` to `places` decimal places.\n",
    "\n",
    "    Every element is multiplied by 10**places, passed through the built-in\n",
    "    `round`, and divided by the same factor again.\n",
    "\n",
    "    Args:\n",
    "        arr: iterable of numbers.\n",
    "        places: number of decimal places to keep (default 6).\n",
    "\n",
    "    Returns:\n",
    "        A new list holding the rounded values in input order.\n",
    "    \"\"\"\n",
    "    scale = 10**places\n",
    "    rounded = []\n",
    "    for value in arr:\n",
    "        rounded.append(round(value * scale) / scale)\n",
    "    return rounded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Ordered mapping that collects the generated fixture, keyed by pipeline name;\n",
    "# the export cell further down serializes it to JSON.\n",
    "DATA = OrderedDict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### pipeline 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "random_seed = 1003\n",
    "data_in_shape = (8, 8, 2)\n",
    "\n",
    "# layer stack under test: a 3x3 ReLU convolution followed by batch normalization\n",
    "layers = [\n",
    "    Conv2D(\n",
    "        4, (3, 3),\n",
    "        activation='relu',\n",
    "        padding='valid',\n",
    "        strides=(1, 1),\n",
    "        data_format='channels_last',\n",
    "        use_bias=True\n",
    "    ),\n",
    "    BatchNormalization(epsilon=1e-05, axis=-1, center=True, scale=True)\n",
    "]\n",
    "\n",
    "# wire the layers together, in list order, into a functional model\n",
    "input_layer = Input(shape=data_in_shape)\n",
    "x = input_layer\n",
    "for layer in layers:\n",
    "    x = layer(x)\n",
    "output_layer = x\n",
    "model = Model(inputs=input_layer, outputs=output_layer)\n",
    "\n",
    "# reproducible input sample with values in [-1, 1)\n",
    "np.random.seed(random_seed)\n",
    "data_in = 2 * np.random.random(data_in_shape) - 1\n",
    "\n",
    "# replace every weight array with seeded random values (one seed per array)\n",
    "weights = []\n",
    "for i, w in enumerate(model.get_weights()):\n",
    "    np.random.seed(random_seed + i)\n",
    "    sample = np.random.random(w.shape)\n",
    "    # array 5 keeps the raw [0, 1) draw because it must be positive (the\n",
    "    # original note calls it the std); all other arrays are rescaled to [-1, 1)\n",
    "    weights.append(sample if i == 5 else 2 * sample - 1)\n",
    "model.set_weights(weights)\n",
    "\n",
    "# run the model on a batch that holds only the single input sample\n",
    "result = model.predict(np.array([data_in]))\n",
    "data_out = result[0]\n",
    "data_out_shape = data_out.shape\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(data_out.ravel().tolist())\n",
    "\n",
    "# record the flattened input, weights and expected output with their shapes\n",
    "weights_formatted = [\n",
    "    {'data': format_decimal(w.ravel().tolist()), 'shape': w.shape}\n",
    "    for w in weights\n",
    "]\n",
    "DATA['pipeline_03'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': weights_formatted,\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### export for Keras.js tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# destination of the fixture consumed by the Keras.js tests\n",
    "filename = '../../test/data/pipeline/03.json'\n",
    "# create the target directory when missing; exist_ok makes this a single call\n",
    "# with no gap between checking for the directory and creating it\n",
    "os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "# serialize everything collected in DATA as one JSON document\n",
    "with open(filename, 'w') as f:\n",
    "    json.dump(DATA, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"pipeline_03\": {\"input\": {\"data\": [0.195612, -0.128132, -0.96626, 0.193375, 0.789956, 0.069255, -0.988089, 0.804359, 0.509039, -0.655792, 0.460058, -0.25375, -0.635374, -0.109318, -0.426266, -0.178438, -0.113165, -0.645789, -0.480558, -0.540853, -0.191122, 0.019685, -0.171947, -0.403521, -0.068903, -0.132053, -0.528253, 0.274125, -0.60827, 0.987533, 0.66134, 0.848859, -0.430572, -0.226527, 0.17207, -0.090301, 0.682272, -0.362781, -0.109492, -0.402323, -0.330768, -0.492729, -0.097343, 0.594172, -0.194273, -0.852785, 0.45843, -0.205545, 0.147448, -0.39129, -0.765637, -0.347263, -0.861972, 0.225595, 0.950601, 0.3874, 0.508384, 0.734384, 0.504576, -0.118712, 0.674019, 0.120606, 0.529139, 0.574269, -0.154725, -0.903845, -0.025097, -0.53235, -0.14829, 0.669766, -0.346025, -0.118921, -0.13139, -0.515193, 0.978336, 0.988066, 0.929787, 0.69695, 0.729349, -0.425137, 0.20707, 0.055498, -0.75746, -0.126905, -0.721594, -0.781276, -0.773981, -0.738916, -0.367574, 0.956481, 0.718737, 0.899696, 0.705012, -0.369647, 0.453032, 0.803767, 0.473406, 0.778174, -0.821354, 0.47794, 0.954554, -0.953411, -0.895929, 0.730168, 0.439185, -0.204223, -0.308631, -0.742228, 0.659115, 0.878338, 0.492655, 0.024586, -0.217403, -0.643025, 0.705032, -0.533673, 0.875713, -0.8079, -0.670755, -0.803671, 0.112563, -0.911985, 0.342553, 0.715974, 0.819998, -0.245518, -0.837444, 0.428726], \"shape\": [8, 8, 2]}, \"weights\": [{\"data\": [0.195612, -0.128132, -0.96626, 0.193375, 0.789956, 0.069255, -0.988089, 0.804359, 0.509039, -0.655792, 0.460058, -0.25375, -0.635374, -0.109318, -0.426266, -0.178438, -0.113165, -0.645789, -0.480558, -0.540853, -0.191122, 0.019685, -0.171947, -0.403521, -0.068903, -0.132053, -0.528253, 0.274125, -0.60827, 0.987533, 0.66134, 0.848859, -0.430572, -0.226527, 0.17207, -0.090301, 0.682272, -0.362781, -0.109492, -0.402323, -0.330768, -0.492729, -0.097343, 0.594172, -0.194273, -0.852785, 0.45843, -0.205545, 0.147448, -0.39129, -0.765637, -0.347263, -0.861972, 0.225595, 
0.950601, 0.3874, 0.508384, 0.734384, 0.504576, -0.118712, 0.674019, 0.120606, 0.529139, 0.574269, -0.154725, -0.903845, -0.025097, -0.53235, -0.14829, 0.669766, -0.346025, -0.118921], \"shape\": [3, 3, 2, 4]}, {\"data\": [-0.922097, 0.712992, 0.493001, 0.727856], \"shape\": [4]}, {\"data\": [0.318429, -0.858397, -0.059042, 0.68597], \"shape\": [4]}, {\"data\": [0.486255, -0.547151, 0.285068, 0.764711], \"shape\": [4]}, {\"data\": [0.0965, 0.594443, -0.987782, 0.431322], \"shape\": [4]}, {\"data\": [0.614003, 0.92974, 0.25491, 0.616435], \"shape\": [4]}], \"expected\": {\"data\": [0.44704, -0.017957, 0.169558, 0.387869, 0.541758, -1.061265, -0.026381, 0.387869, 0.44704, -0.643588, 0.169558, 1.087162, 0.555921, -0.179094, 0.16213, 0.387869, 0.666479, -0.017957, -0.062361, 0.683801, 0.44704, -0.017957, 0.05133, 0.945016, 0.44704, -1.051565, 0.156528, 1.096913, 0.44704, -0.232774, 0.033581, 0.723631, 0.44704, -2.162143, 0.013993, 1.267564, 0.44704, -0.16673, -0.026175, 0.553551, 0.44704, -1.284314, 0.135679, 0.387869, 0.44704, -1.597521, 0.169558, 0.995197, 0.44704, -0.627101, 0.169558, 0.387869, 0.857094, -0.017957, 0.06014, 1.412131, 0.44704, -0.017957, -0.095206, 1.545397, 0.44704, -0.463094, 0.157524, 0.471388, 0.44704, -1.246434, -0.053894, 1.694503, 0.44704, -0.017957, 0.027412, 1.561352, 0.44704, -0.094143, 0.110956, 0.997482, 0.44704, -0.11569, 0.162752, 0.387869, 0.44704, -1.431214, 0.169558, 0.556655, 0.44704, -0.017957, 0.169558, 1.458673, 0.512743, -0.017957, 0.009158, 1.213484, 0.44704, -0.698585, 0.151032, 1.04226, 0.44704, -0.075071, -0.035284, 0.387869, 0.44704, -4.176817, -0.03165, 0.740703, 0.789254, -0.017957, 0.169558, 0.678272, 0.44704, -0.017957, -0.101369, 0.387869, 0.44704, -0.017957, 0.169558, 0.387869, 0.472218, -0.017957, 0.169558, 2.361498, 0.44704, -1.596942, 0.169558, 2.340036, 0.44704, -2.324934, -0.128996, 1.040756, 0.44704, -0.017957, 0.169558, 0.387869, 0.44704, -2.042173, 0.087067, 0.387869, 0.547587, -0.017957, 0.169558, 1.179826, 
0.727152, -0.200261, 0.169558, 0.72271], \"shape\": [6, 6, 4]}}}\n"
     ]
    }
   ],
   "source": [
    "# print the JSON serialization of DATA (the same mapping the previous cell\n",
    "# writes to disk) so the fixture can be inspected inline\n",
    "print(json.dumps(DATA))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
